{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w0OkH2378WXZ"
      },
      "source": [
        "# Overview\n",
        "\n",
        "This colab demonstrates an example of using ViP-DeepLab to output sequence-level\n",
        "depth-aware video panoptic segmentation predictions. It loads an exported ViP-DeepLab model trained for Cityscapes-DVPS and visualizes the outputs for a sequence of the dataset.\n",
        "\n",
        "[1] Marius Cordts, Mohamed Omran, Sebastian Ramos, Timo Rehfeld, Markus Enzweiler, Rodrigo Benenson, Uwe Franke, Stefan Roth, and Bernt Schiele. The cityscapes dataset for semantic urban scene understanding. CVPR, 2016.\n",
        "\n",
        "[2] Dahun Kim, Sanghyun Woo, Joon-Young Lee, and In So Kweon. Video panoptic segmentation. CVPR, 2020.\n",
        "\n",
        "[3] Siyuan Qiao, Yukun Zhu, Hartwig Adam, Alan Yuille, and Liang-Chieh Chen.\n",
        "ViP-DeepLab: Learning Visual Perception with Depth-aware Video Panoptic Segmentation. CVPR, 2021."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rCknVEZw_RWr"
      },
      "source": [
        "# Inputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GQ4o2FEM9kCP"
      },
      "outputs": [],
      "source": [
        "# MODEL_URL The URL of the exported ViP-DeepLab model archive (.tar.gz).\n",
        "MODEL_URL = 'https://storage.googleapis.com/gresearch/tf-deeplab/saved_model/resnet50_beta_os32_vip_deeplab_cityscapes_dvps_train_saved_model.tar.gz' #@param {type:\"string\"}\n",
        "\n",
        "# SEQUENCE_PATTERN The file name (glob) pattern for the input sequence.\n",
        "# The default is an empty placeholder and must be set by the user.\n",
        "SEQUENCE_PATTERN = '' #@param {type:\"string\"}\n",
        "\n",
        "# LABEL_DIVISOR The label divisor for the dataset. A panoptic ID is decoded\n",
        "# as semantic_class * LABEL_DIVISOR + instance_id.\n",
        "LABEL_DIVISOR = 1000 #@param {type:\"integer\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SXwvG3cE_Zw3"
      },
      "source": [
        "# ViP-DeepLab Sequence Inference Class"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "26x9jRatkQIF"
      },
      "outputs": [],
      "source": [
        "#@title Import Python Libraries\n",
        "import collections\n",
        "import copy\n",
        "import os\n",
        "import tempfile\n",
        "import typing\n",
        "import urllib\n",
        "import urllib.request\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iM9hgSBJlOCE"
      },
      "outputs": [],
      "source": [
        "#@title Define ViP-DeepLab Sequence Inference Class\n",
        "\n",
        "class ViPDeepLab:\n",
        "  \"\"\"Sequence inference model for ViP-DeepLab.\n",
        "\n",
        "  Frame-level ViP-DeepLab takes two consecutive frames as inputs and generates\n",
        "  temporally consistent depth-aware video panoptic predictions. Sequence-level\n",
        "  ViP-DeepLab takes a sequence of images as input and propagates the instance\n",
        "  IDs between all 2-frame predictions made by frame-level ViP-DeepLab.\n",
        "\n",
        "  Siyuan Qiao, Yukun Zhu, Hartwig Adam, Alan Yuille, and Liang-Chieh Chen.\n",
        "  ViP-DeepLab: Learning Visual Perception with Depth-aware Video Panoptic\n",
        "  Segmentation. CVPR, 2021.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, model_path: str, label_divisor: int):\n",
        "    \"\"\"Initializes a ViP-DeepLab model.\n",
        "\n",
        "    Args:\n",
        "      model_path: A string specifying the path to the exported ViP-DeepLab.\n",
        "      label_divisor: An integer specifying the dataset label divisor.\n",
        "    \"\"\"\n",
        "    self._model = tf.saved_model.load(model_path)\n",
        "    self._label_divisor = label_divisor\n",
        "    # Instance IDs above this offset in the next-frame prediction are treated\n",
        "    # as instances that are new in that frame (see `infer`).\n",
        "    self._overlap_offset = label_divisor // 2\n",
        "    # Multiplier used to pack a pair of panoptic IDs into a single int64 key.\n",
        "    self._combine_offset = 2 ** 32\n",
        "    self.reset()\n",
        "\n",
        "  def reset(self):\n",
        "    \"\"\"Resets the sequence predictions.\"\"\"\n",
        "    self._max_instance_id = 0\n",
        "    self._depth_preds = []\n",
        "    self._stitched_panoptic = []\n",
        "    self._last_panoptic = None\n",
        "\n",
        "  def _infer(self, input_array, next_input_array):\n",
        "    \"\"\"Inference for two consecutive input frames.\n",
        "\n",
        "    Args:\n",
        "      input_array: The current frame.\n",
        "      next_input_array: The frame that follows `input_array`.\n",
        "\n",
        "    Returns:\n",
        "      A tuple (depth, panoptic, next_panoptic) of numpy arrays read from the\n",
        "      model outputs 'depth_pred', 'panoptic_pred' and 'next_panoptic_pred'.\n",
        "    \"\"\"\n",
        "    # The model takes both frames stacked along the last axis.\n",
        "    input_array = np.concatenate((input_array, next_input_array), axis=-1)\n",
        "    output = self._model(input_array)\n",
        "    depth = output['depth_pred'].numpy()\n",
        "    panoptic = output['panoptic_pred'].numpy()\n",
        "    next_panoptic = output['next_panoptic_pred'].numpy()\n",
        "    return depth, panoptic, next_panoptic\n",
        "\n",
        "  def infer(self, inputs: typing.List[tf.Tensor]):\n",
        "    \"\"\"Inference for a sequence of input frames.\n",
        "\n",
        "    Every pair of consecutive frames is passed through the frame-level model.\n",
        "    The predictions are stored on the instance and are read with `results()`:\n",
        "    one depth prediction per frame pair and, when at least two frames are\n",
        "    given, one stitched panoptic prediction per input frame.\n",
        "\n",
        "    Args:\n",
        "      inputs: A list of tf.Tensor storing the input frames.\n",
        "    \"\"\"\n",
        "    self.reset()\n",
        "    for input_idx in range(len(inputs) - 1):\n",
        "      depth, panoptic, next_panoptic = self._infer(inputs[input_idx],\n",
        "                                                   inputs[input_idx + 1])\n",
        "      self._depth_preds.append(copy.deepcopy(depth))\n",
        "      # Propagate instance ID from last_panoptic to next_panoptic based on ID\n",
        "      # matching between panoptic and last_panoptic. panoptic and last_panoptic\n",
        "      # stores panoptic predictions for the same frame but from different runs.\n",
        "      # The mask of new instances is taken before any ID is rewritten.\n",
        "      next_new_mask = next_panoptic % self._label_divisor \u003e self._overlap_offset\n",
        "      if self._last_panoptic is not None:\n",
        "        # Pack each (last_panoptic ID, panoptic ID) pixel pair into one key and\n",
        "        # count how many pixels share every pair.\n",
        "        intersection = (\n",
        "            self._last_panoptic.astype(np.int64) * self._combine_offset +\n",
        "            panoptic.astype(np.int64))\n",
        "        intersection_ids, intersection_counts = np.unique(\n",
        "            intersection, return_counts=True)\n",
        "        # Visit the pairs from the smallest to the largest overlap so that, for\n",
        "        # every ID in panoptic, the pair with the largest overlap is kept.\n",
        "        intersection_ids = intersection_ids[np.argsort(intersection_counts)]\n",
        "        id_mapping = {}\n",
        "        for intersection_id in intersection_ids:\n",
        "          last_panoptic_id = intersection_id // self._combine_offset\n",
        "          panoptic_id = intersection_id % self._combine_offset\n",
        "          id_mapping[panoptic_id] = last_panoptic_id\n",
        "        # Build the masks from an unmodified copy. Relabeling in place would\n",
        "        # let an already propagated ID be matched and relabeled a second time,\n",
        "        # and would leave no pixels for a larger overlap to claim.\n",
        "        next_panoptic_before = next_panoptic.copy()\n",
        "        for panoptic_id, last_panoptic_id in id_mapping.items():\n",
        "          next_panoptic[next_panoptic_before == panoptic_id] = last_panoptic_id\n",
        "      # Adjust the IDs for the new instances in next_panoptic.\n",
        "      self._max_instance_id = max(self._max_instance_id,\n",
        "                                  np.max(panoptic % self._label_divisor))\n",
        "      next_panoptic_cls = next_panoptic // self._label_divisor\n",
        "      next_panoptic_ins = next_panoptic % self._label_divisor\n",
        "      # New instances are renumbered to continue after the largest instance ID\n",
        "      # seen so far in the sequence.\n",
        "      next_panoptic_ins[next_new_mask] = (\n",
        "          next_panoptic_ins[next_new_mask] - self._overlap_offset\n",
        "          + self._max_instance_id)\n",
        "      next_panoptic = (\n",
        "          next_panoptic_cls * self._label_divisor + next_panoptic_ins)\n",
        "      # The first frame of the sequence has no earlier run to stitch with.\n",
        "      if not self._stitched_panoptic:\n",
        "        self._stitched_panoptic.append(copy.deepcopy(panoptic))\n",
        "      self._stitched_panoptic.append(copy.deepcopy(next_panoptic))\n",
        "      self._max_instance_id = max(self._max_instance_id,\n",
        "                                  np.max(next_panoptic % self._label_divisor))\n",
        "      self._last_panoptic = copy.deepcopy(next_panoptic)\n",
        "\n",
        "  def results(self):\n",
        "    \"\"\"Returns the sequence inference results.\n",
        "\n",
        "    Returns:\n",
        "      A tuple (depth_preds, stitched_panoptic) of the lists filled by `infer`.\n",
        "    \"\"\"\n",
        "    return self._depth_preds, self._stitched_panoptic"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uPWbEN1g_k8I"
      },
      "source": [
        "# A Sequence Example on Cityscapes-DVPS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s2jFjQ0wN9VJ"
      },
      "outputs": [],
      "source": [
        "#@title Download pre-trained checkpoint.\n",
        "# Name of the directory that is expected to hold the extracted model.\n",
        "# NOTE(review): hard-coded separately from MODEL_URL; presumably it must match\n",
        "# the top-level directory inside the archive, so update both together.\n",
        "model_name = 'resnet50_beta_os32_vip_deeplab_cityscapes_dvps_train_saved_model'\n",
        "# Download the archive into a fresh temporary directory and unpack it there.\n",
        "model_dir = tempfile.mkdtemp()\n",
        "download_path = os.path.join(model_dir, 'model.tar.gz')\n",
        "urllib.request.urlretrieve(MODEL_URL, download_path)\n",
        "!tar -xzvf {download_path} -C {model_dir}\n",
        "# The model is later loaded from the `exports` subdirectory.\n",
        "model_path = os.path.join(model_dir, model_name, 'exports')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9d0vm8gHlGFg"
      },
      "outputs": [],
      "source": [
        "#@title Run Inference on Examples from Cityscapes-DVPS\n",
        "vip_deeplab = ViPDeepLab(model_path=model_path, label_divisor=LABEL_DIVISOR)\n",
        "# Only the first three frames of the sequence are used.\n",
        "filenames = sorted(tf.io.gfile.glob(SEQUENCE_PATTERN))[0:3]\n",
        "# Fail early with a clear message instead of an IndexError further below.\n",
        "if not filenames:\n",
        "  raise ValueError(\n",
        "      'No input frames match SEQUENCE_PATTERN={!r}. Please set it to the '\n",
        "      'file name pattern of the input sequence.'.format(SEQUENCE_PATTERN))\n",
        "inputs = [\n",
        "    tf.image.decode_png(tf.io.read_file(filename)) for filename in filenames\n",
        "]\n",
        "# The model consumes pairs of consecutive frames, so the last frame is\n",
        "# repeated to obtain a depth prediction for every input frame.\n",
        "inputs.append(inputs[-1])\n",
        "vip_deeplab.infer(inputs)\n",
        "depth_preds, stitched_panoptic = vip_deeplab.results()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EVr8fIBqKgfR"
      },
      "outputs": [],
      "source": [
        "#@title Visualization Utilities\n",
        "\n",
        "# Bundles the dataset properties used to colorize a panoptic prediction.\n",
        "DatasetInfo = collections.namedtuple(\n",
        "    'DatasetInfo',\n",
        "    'num_classes, label_divisor, thing_list, colormap, class_names')\n",
        "\n",
        "\n",
        "def _cityscapes_label_colormap():\n",
        "  \"\"\"Builds the label colormap of the CITYSCAPES segmentation benchmark.\n",
        "\n",
        "  See more about CITYSCAPES dataset at https://www.cityscapes-dataset.com/\n",
        "  M. Cordts, et al. \"The Cityscapes Dataset for Semantic Urban Scene Understanding.\" CVPR. 2016.\n",
        "\n",
        "  Returns:\n",
        "    A uint8 numpy array of shape (256, 3). Row i holds the RGB color of label\n",
        "    i for the 19 labels listed below; all remaining rows are zero.\n",
        "  \"\"\"\n",
        "  # One RGB triple per label, in label order.\n",
        "  palette = (\n",
        "      (128, 64, 128),\n",
        "      (244, 35, 232),\n",
        "      (70, 70, 70),\n",
        "      (102, 102, 156),\n",
        "      (190, 153, 153),\n",
        "      (153, 153, 153),\n",
        "      (250, 170, 30),\n",
        "      (220, 220, 0),\n",
        "      (107, 142, 35),\n",
        "      (152, 251, 152),\n",
        "      (70, 130, 180),\n",
        "      (220, 20, 60),\n",
        "      (255, 0, 0),\n",
        "      (0, 0, 142),\n",
        "      (0, 0, 70),\n",
        "      (0, 60, 100),\n",
        "      (0, 80, 100),\n",
        "      (0, 0, 230),\n",
        "      (119, 11, 32),\n",
        "  )\n",
        "  colormap = np.zeros((256, 3), dtype=np.uint8)\n",
        "  colormap[:len(palette)] = palette\n",
        "  return colormap\n",
        "\n",
        "\n",
        "def _cityscapes_class_names():\n",
        "  \"\"\"Returns the tuple of the 19 Cityscapes class names.\"\"\"\n",
        "  class_names = (\n",
        "      'road',\n",
        "      'sidewalk',\n",
        "      'building',\n",
        "      'wall',\n",
        "      'fence',\n",
        "      'pole',\n",
        "      'traffic light',\n",
        "      'traffic sign',\n",
        "      'vegetation',\n",
        "      'terrain',\n",
        "      'sky',\n",
        "      'person',\n",
        "      'rider',\n",
        "      'car',\n",
        "      'truck',\n",
        "      'bus',\n",
        "      'train',\n",
        "      'motorcycle',\n",
        "      'bicycle',\n",
        "  )\n",
        "  return class_names\n",
        "\n",
        "\n",
        "def cityscapes_dataset_information():\n",
        "  \"\"\"Assembles the DatasetInfo that describes Cityscapes.\n",
        "\n",
        "  Returns:\n",
        "    A DatasetInfo with 19 classes, a label divisor of 1000, the class ids 11\n",
        "    to 18 as `thing` classes, and the Cityscapes colormap and class names.\n",
        "  \"\"\"\n",
        "  thing_class_ids = tuple(range(11, 19))\n",
        "  colormap = _cityscapes_label_colormap()\n",
        "  class_names = _cityscapes_class_names()\n",
        "  return DatasetInfo(\n",
        "      num_classes=19,\n",
        "      label_divisor=1000,\n",
        "      thing_list=thing_class_ids,\n",
        "      colormap=colormap,\n",
        "      class_names=class_names)\n",
        "\n",
        "\n",
        "def perturb_color(color, noise, used_colors, max_trials=50, random_state=None):\n",
        "  \"\"\"Perturbs the color with some noise.\n",
        "\n",
        "  If `used_colors` is not None, we will return the color that has\n",
        "  not appeared before in it, and the returned color is added to it.\n",
        "  If `used_colors` is None, the first perturbed color is returned.\n",
        "\n",
        "  Args:\n",
        "    color: A numpy array with three elements [R, G, B].\n",
        "    noise: Integer, specifying the amount of perturbing noise (in uint8 range).\n",
        "    used_colors: A set, used to keep track of used colors, or None to skip\n",
        "      the uniqueness check.\n",
        "    max_trials: An integer, maximum trials to generate random color.\n",
        "    random_state: An optional np.random.RandomState. If passed, will be used to\n",
        "      generate random numbers.\n",
        "\n",
        "  Returns:\n",
        "    A perturbed color, clipped to [0, 255]. It has not appeared in used_colors\n",
        "    unless every one of the `max_trials` attempts produced a used color. If\n",
        "    `max_trials` is not positive, the clipped input color is returned.\n",
        "  \"\"\"\n",
        "  if random_state is None:\n",
        "    random_state = np.random\n",
        "\n",
        "  # Fallback so that a result exists even when no trial is run.\n",
        "  random_color = np.clip(color, 0, 255)\n",
        "  for _ in range(max_trials):\n",
        "    random_color = color + random_state.randint(\n",
        "        low=-noise, high=noise + 1, size=3)\n",
        "    random_color = np.clip(random_color, 0, 255)\n",
        "\n",
        "    # Without a record of used colors there is nothing to compare against.\n",
        "    if used_colors is None:\n",
        "      return random_color\n",
        "\n",
        "    if tuple(random_color) not in used_colors:\n",
        "      used_colors.add(tuple(random_color))\n",
        "      return random_color\n",
        "\n",
        "  print('Max trial reached and duplicate color will be used. Please consider '\n",
        "        'increase noise in `perturb_color()`.')\n",
        "  return random_color\n",
        "\n",
        "\n",
        "def color_panoptic_map(panoptic_prediction,\n",
        "                       dataset_info,\n",
        "                       perturb_noise,\n",
        "                       used_colors,\n",
        "                       color_mapping):\n",
        "  \"\"\"Colorizes a panoptic prediction map.\n",
        "\n",
        "  `stuff` classes are painted with their colormap color. Every `thing`\n",
        "  instance gets a randomly perturbed version of its class color; that color is\n",
        "  stored in `color_mapping` and reused when the same panoptic id shows up again.\n",
        "\n",
        "  Args:\n",
        "    panoptic_prediction: A 2D numpy array, panoptic prediction from deeplab\n",
        "      model.\n",
        "    dataset_info: A DatasetInfo object, dataset associated to the model.\n",
        "    perturb_noise: Integer, the amount of noise (in uint8 range) added to each\n",
        "      instance of the same semantic class.\n",
        "    used_colors: A dictionary mapping semantic ids to the set of colors used\n",
        "      for that class (e.g. a collections.defaultdict(set)). Updated in place.\n",
        "    color_mapping: A dict, used to map existing panoptic ids to their colors.\n",
        "      Updated in place.\n",
        "\n",
        "  Returns:\n",
        "    colored_panoptic_map: A 3D numpy array with last dimension of 3, colored\n",
        "      panoptic prediction map.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If `panoptic_prediction` is not 2-D.\n",
        "  \"\"\"\n",
        "  if panoptic_prediction.ndim != 2:\n",
        "    raise ValueError('Expect 2-D panoptic prediction. Got {}'.format(\n",
        "        panoptic_prediction.shape))\n",
        "\n",
        "  divisor = dataset_info.label_divisor\n",
        "  semantic_map, instance_map = np.divmod(panoptic_prediction, divisor)\n",
        "  colored_panoptic_map = np.zeros(\n",
        "      tuple(panoptic_prediction.shape) + (3,), dtype=np.uint8)\n",
        "\n",
        "  # Use a fixed seed to reproduce the same visualization.\n",
        "  rng = np.random.RandomState(0)\n",
        "\n",
        "  for semantic_id in np.unique(semantic_map):\n",
        "    semantic_mask = semantic_map == semantic_id\n",
        "    if semantic_id not in dataset_info.thing_list:\n",
        "      # For `stuff` class, we use the defined semantic color.\n",
        "      stuff_color = dataset_info.colormap[semantic_id]\n",
        "      colored_panoptic_map[semantic_mask] = stuff_color\n",
        "      used_colors[semantic_id].add(tuple(dataset_info.colormap[semantic_id]))\n",
        "      continue\n",
        "\n",
        "    # For `thing` class, we will add a small amount of random noise to its\n",
        "    # correspondingly predefined semantic segmentation colormap.\n",
        "    for instance_id in np.unique(instance_map[semantic_mask]):\n",
        "      instance_mask = semantic_mask \u0026 (instance_map == instance_id)\n",
        "      panoptic_id = semantic_id * divisor + instance_id\n",
        "      if panoptic_id not in color_mapping:\n",
        "        color_mapping[panoptic_id] = perturb_color(\n",
        "            dataset_info.colormap[semantic_id],\n",
        "            perturb_noise,\n",
        "            used_colors[semantic_id],\n",
        "            random_state=rng)\n",
        "      colored_panoptic_map[instance_mask] = color_mapping[panoptic_id]\n",
        "  return colored_panoptic_map\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NYW1ziM1cumG"
      },
      "outputs": [],
      "source": [
        "#@title Visualize the Predictions\n",
        "# Both containers are shared across frames so that colors already handed out\n",
        "# are remembered from one frame to the next.\n",
        "used_colors = collections.defaultdict(set)\n",
        "color_mapping = {}\n",
        "dataset_info = cityscapes_dataset_information()\n",
        "for frame_idx, _ in enumerate(filenames):\n",
        "  fig, (image_ax, depth_ax, panoptic_ax) = plt.subplots(\n",
        "      1, 3, figsize=(18, 6))\n",
        "\n",
        "  image_ax.title.set_text('Input Image')\n",
        "  image_ax.imshow(np.squeeze(inputs[frame_idx]))\n",
        "\n",
        "  depth_ax.title.set_text('Depth')\n",
        "  depth_ax.imshow(np.squeeze(depth_preds[frame_idx]))\n",
        "\n",
        "  panoptic_ax.title.set_text('Video Panoptic Segmentation')\n",
        "  colored_panoptic = color_panoptic_map(\n",
        "      np.squeeze(stitched_panoptic[frame_idx]), dataset_info, 60,\n",
        "      used_colors, color_mapping)\n",
        "  panoptic_ax.imshow(colored_panoptic)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "ViP-DeepLab_Demo.ipynb",
      "provenance": [
        {
          "file_id": "1BUG7E1Skt4ct0KgrvpnDIMzgKu5jSO4I",
          "timestamp": 1645076177182
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
